{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation of GPT2 traied on TinyStories\n",
    "\n",
    "Load the model and tokenizer\n",
    "Model checkpoint from training `gpt2_tinystories.py` : `gpt2-tinystories-final`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded on:  cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import (\n",
    "    GPT2LMHeadModel,\n",
    "    GPT2TokenizerFast,\n",
    ")\n",
    "\n",
    "# Load tokenizer\n",
    "tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Load model and move to GPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2-tinystories-fr-scratch-final')\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "print(\"Model loaded on: \", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate sample stories given prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing model generation:\n",
      "\n",
      "Prompt: Once upon a time\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text: Once upon a time, there was a little girl named Lily. She loved to play in the park with her friends. One day, Lily's mom took her to the store to buy some milk. While they were there, a big dog came running towards them. Lily got scared and started to cry. \n",
      "\n",
      "Her mom told her not to worry and that they would be safe in their house. They bought some ice cream and went back to their car. When they got home, they sat\n",
      "\n",
      "Prompt: The little dog\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text: The little dog was walking down the street. He was so excited! He had never been so close to the playground before.\n",
      "\n",
      "As he walked, he saw a tall slide. It looked so fun! The dog ran to it and started to climb. His little legs took him up the slide and he felt like he was flying!\n",
      " \n",
      "Suddenly, the little boy saw the dog and shouted, \"Hey, that's my slide! Give it back!\"\n",
      " The little kid was angry\n",
      "\n",
      "Prompt: In the garden\n",
      "Generated text: In the garden, there was a big, red wagon. It was so big that it could carry lots of things. Little Jane was playing with her toys when she saw something shiny in the corner. She went over to take a look and saw it was an onion.\n",
      "\n",
      "Jane was very excited and she said to her mom, \"Mommy, look! An onion!\" Her mom smiled and said, â€œThatâ€™s an important onion, Jane. Letâ\n"
     ]
    }
   ],
   "source": [
    "# Test generation function\n",
    "def generate_text(prompt, max_length=100):\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    attention_mask = inputs['attention_mask'].to(model.device)\n",
    "    outputs = model.generate(\n",
    "        inputs.input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        max_length=max_length,\n",
    "        num_return_sequences=1,\n",
    "        no_repeat_ngram_size=2,\n",
    "        do_sample=True,\n",
    "        top_k=50,\n",
    "        top_p=0.95,\n",
    "        temperature=0.7\n",
    "    )\n",
    "    return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "# Test the model\n",
    "test_prompts = [\n",
    "    \"Once upon a time\",\n",
    "    \"The little dog\",\n",
    "    \"In the garden\"\n",
    "]\n",
    "\n",
    "print(\"\\nTesting model generation:\")\n",
    "for prompt in test_prompts:\n",
    "    print(f\"\\nPrompt: {prompt}\")\n",
    "    print(f\"Generated text: {generate_text(prompt)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use an interactive demo\n",
    "\n",
    "This will prompt the user for the beginning of a story and then generate the rest of the story.\n",
    "Type `exit` to quit the demo.\n",
    "\n",
    "```python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text: At the river, there was a little girl named Lily. She was three years old. One day, she was walking near the edge of the lake when she saw something incredible. It was an enormous boat!\n",
      "\n",
      "Lily was so excited! She wanted to get closer and take a closer look. So, with one big push, Lily started to pull the boat closer to the shore.\n",
      " \n",
      "When she reached the beach, the sun was shining so brightly that it made the water spark\n"
     ]
    }
   ],
   "source": [
    "# make an interactive prompt that asks for user input\n",
    "def interactive_prompt():\n",
    "    while True:\n",
    "        prompt = input(\"\\nEnter a prompt: \")\n",
    "        if prompt.lower() == \"exit\":\n",
    "            break\n",
    "        print(f\"Generated text: {generate_text(prompt)}\")\n",
    "        \n",
    "interactive_prompt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate the model on a validation set\n",
    "\n",
    "Using perplexity as the evaluation metric\n",
    "\n",
    "```python"
   ]
  },
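  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a sketch of the token-weighted average that the following evaluation cell accumulates, assuming `outputs.loss` is the mean cross-entropy over the predicted (non-ignored) tokens of a batch. The symbols $N$, $B$, $L_b$ and $n_b$ are introduced here for illustration only. For a sequence of $N$ predicted tokens,\n",
    "\n",
    "$$\\mathrm{PPL} = \\exp\\left(-\\frac{1}{N}\\sum_{i=1}^{N} \\log p(x_i \\mid x_{<i})\\right).$$\n",
    "\n",
    "When batch $b$ (of $B$ batches) reports a mean loss $L_b$ over $n_b$ predicted tokens, the corpus-level value weights each batch by its token count:\n",
    "\n",
    "$$\\bar{L} = \\frac{\\sum_{b=1}^{B} n_b L_b}{\\sum_{b=1}^{B} n_b}, \\qquad \\mathrm{PPL} = \\exp(\\bar{L}).$$\n",
    "\n",
    "Averaging per-batch perplexities $\\exp(L_b)$ is a different quantity: by Jensen's inequality it is at least $\\exp$ of the unweighted mean loss, and it is dominated by the worst batches."
   ]
  },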
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "347b809bab584a0fb911f1248f96e4ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=4):   0%|          | 0/21990 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5931585a4a84b0a831bb957df2f7f56",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Evaluating', max=2749)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Results on TinyStories validation set:\n",
      "Average Loss: 14.6298\n",
      "Perplexity: 2257549.5746\n",
      "Total tokens evaluated: 4675638\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "from transformers import DataCollatorForLanguageModeling\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
    "from IPython.display import display\n",
    "import ipywidgets as widgets\n",
    "\n",
    "\n",
    "dataset = load_dataset(\"roneneldan/TinyStories\")\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize dataset examples.\"\"\"\n",
    "    return tokenizer(\n",
    "        examples[\"text\"],\n",
    "        truncation=True,\n",
    "        max_length=512,\n",
    "        padding=\"max_length\"\n",
    "    )\n",
    "\n",
    "# Tokenize test dataset\n",
    "tokenized_test = dataset[\"validation\"].map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=dataset[\"validation\"].column_names,\n",
    "    num_proc=4  # Parallel processing\n",
    ")\n",
    "\n",
    "# Create data collator\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer,\n",
    "    mlm=False\n",
    ")\n",
    "\n",
    "# Create dataloader\n",
    "test_dataloader = DataLoader(\n",
    "    tokenized_test,\n",
    "    batch_size=8,\n",
    "    collate_fn=data_collator,\n",
    "    shuffle=False\n",
    ")\n",
    "\n",
    "# Initialize metrics\n",
    "total_loss = 0\n",
    "total_tokens = 0\n",
    "\n",
    "\n",
    "device = model.device\n",
    "\n",
    "progress_bar = widgets.IntProgress(min=0, max=len(test_dataloader), description='Evaluating')\n",
    "display(progress_bar)\n",
    "\n",
    "\n",
    "# Evaluation loop\n",
    "with torch.no_grad():\n",
    "    for i, batch in enumerate(test_dataloader):\n",
    "        # Prepare input\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = input_ids.clone()\n",
    "        \n",
    "        # Forward pass\n",
    "        outputs = model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels\n",
    "        )\n",
    "        \n",
    "        # Calculate loss\n",
    "        loss = outputs.loss\n",
    "        \n",
    "        # Count non-padding tokens\n",
    "        non_pad_mask = labels.ne(tokenizer.pad_token_id)\n",
    "        num_tokens = non_pad_mask.sum().item()\n",
    "        \n",
    "        # Accumulate loss and token count\n",
    "        total_loss += loss.item() * num_tokens\n",
    "        total_tokens += num_tokens\n",
    "        \n",
    "        progress_bar.value += 1\n",
    "\n",
    "# Calculate perplexity\n",
    "avg_loss = total_loss / total_tokens\n",
    "perplexity = math.exp(avg_loss)\n",
    "\n",
    "print(f\"\\nResults on TinyStories validation set:\")\n",
    "print(f\"Average Loss: {avg_loss:.4f}\")\n",
    "print(f\"Perplexity: {perplexity:.4f}\")\n",
    "print(f\"Total tokens evaluated: {total_tokens}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Text: Once upon a time\n",
      "Perplexity: 1.01\n",
      "\n",
      "Text: The little dog\n",
      "Perplexity: 38.56\n",
      "\n",
      "Text: In the garden\n",
      "Perplexity: 5.08\n"
     ]
    }
   ],
   "source": [
    "def calculate_perplexity(text, model, tokenizer, device=\"cuda\"):\n",
    "    # Encode text\n",
    "    encodings = tokenizer(\n",
    "        text,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=True,\n",
    "        max_length=512\n",
    "    )\n",
    "    \n",
    "    # Move to device\n",
    "    input_ids = encodings[\"input_ids\"].to(device)\n",
    "    \n",
    "    # Calculate perplexity\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids, labels=input_ids)\n",
    "        loss = outputs.loss\n",
    "    \n",
    "    return math.exp(loss.item())\n",
    "\n",
    "# Test the model\n",
    "texts = [\n",
    "    \"Once upon a time\",\n",
    "    \"The little dog\",\n",
    "    \"In the garden\"\n",
    "]\n",
    "\n",
    "# Calculate perplexity for each text\n",
    "for text in texts:\n",
    "    perplexity = calculate_perplexity(text, model, tokenizer, device)\n",
    "    print(f\"\\nText: {text}\")\n",
    "    print(f\"Perplexity: {perplexity:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss: 10.3397\n",
      "Perplexity: 86240.4700\n"
     ]
    }
   ],
   "source": [
    "# Evaluation loop\n",
    "total_loss = 0\n",
    "total_perplexity = 0\n",
    "counter = 0\n",
    "with torch.no_grad():\n",
    "    for i, batch in enumerate(test_dataloader):\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device) \n",
    "        labels = input_ids.clone()\n",
    "        \n",
    "        # Calculate perplexity\n",
    "        with torch.no_grad():\n",
    "            outputs = model(input_ids, labels=input_ids)\n",
    "            loss = outputs.loss\n",
    "            total_loss += loss.item()\n",
    "            total_perplexity += math.exp(loss.item())\n",
    "            counter += 1\n",
    "            \n",
    "        progress_bar.value += 1\n",
    "       \n",
    "# Calculate average loss and perplexity\n",
    "avg_loss = total_loss / counter\n",
    "perplexity = total_perplexity / counter\n",
    "\n",
    "print(f\"Average Loss: {avg_loss:.4f}\")\n",
    "print(f\"Perplexity: {perplexity:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "app",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
